{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f4b72fd7",
   "metadata": {},
   "source": [
    "## Description:\n",
    "这个是sharedBottom模型的demo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "324f6ca1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:\n",
      "DeepCTR version 0.9.0 detected. Your version is 0.8.2.\n",
      "Use `pip install -U deepctr` to upgrade.Changelog: https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From C:\\Users\\ZHONGQ~1\\AppData\\Local\\Temp/ipykernel_6560/2989541817.py:22: The name tf.keras.backend.set_session is deprecated. Please use tf.compat.v1.keras.backend.set_session instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From C:\\Users\\ZHONGQ~1\\AppData\\Local\\Temp/ipykernel_6560/2989541817.py:22: The name tf.keras.backend.set_session is deprecated. Please use tf.compat.v1.keras.backend.set_session instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "\n",
    "from sklearn.preprocessing import MinMaxScaler, StandardScaler\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "from sklearn.metrics import roc_auc_score, mean_squared_error, mean_absolute_error\n",
    "from deepctr.feature_column import SparseFeat, VarLenSparseFeat, DenseFeat\n",
    "from deepctr.feature_column import get_feature_names\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "config = tf.compat.v1.ConfigProto()\n",
    "config.gpu_options.allow_growth = True      # TensorFlow按需分配显存\n",
    "config.gpu_options.per_process_gpu_memory_fraction = 0.5  # 指定显存分配比例\n",
    "tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "577efc09",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = '../data_process'\n",
    "data = pd.read_csv(os.path.join(data_path, 'train_data.csv'), index_col=0, parse_dates=['expo_time'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "88a84ee4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 选择出需要用到的列\n",
    "use_cols = ['user_id', 'article_id', 'expo_time', 'net_status', 'exop_position', 'duration', 'device', 'city', 'age', 'gender', 'img_num', 'cat_1', 'click']\n",
    "data_new = data[use_cols]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ce2ea472",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 由于这个data_new的数据量还是太大， 我电脑训练不动， 所以这里再进行一波抽样\n",
    "users = set(data_new['user_id'])\n",
    "sampled_users = random.sample(users, 1000)\n",
    "data_new = data_new[data_new['user_id'].isin(sampled_users)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "84eb94aa",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>article_id</th>\n",
       "      <th>expo_time</th>\n",
       "      <th>net_status</th>\n",
       "      <th>exop_position</th>\n",
       "      <th>duration</th>\n",
       "      <th>device</th>\n",
       "      <th>city</th>\n",
       "      <th>age</th>\n",
       "      <th>gender</th>\n",
       "      <th>img_num</th>\n",
       "      <th>cat_1</th>\n",
       "      <th>click</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>10661</th>\n",
       "      <td>60</td>\n",
       "      <td>2174</td>\n",
       "      <td>2021-06-30 13:36:57</td>\n",
       "      <td>0</td>\n",
       "      <td>17</td>\n",
       "      <td>0</td>\n",
       "      <td>174</td>\n",
       "      <td>237</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>15</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10662</th>\n",
       "      <td>60</td>\n",
       "      <td>4458</td>\n",
       "      <td>2021-06-30 13:36:57</td>\n",
       "      <td>0</td>\n",
       "      <td>21</td>\n",
       "      <td>0</td>\n",
       "      <td>174</td>\n",
       "      <td>237</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.033149</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10663</th>\n",
       "      <td>60</td>\n",
       "      <td>4037</td>\n",
       "      <td>2021-06-30 13:40:23</td>\n",
       "      <td>0</td>\n",
       "      <td>24</td>\n",
       "      <td>0</td>\n",
       "      <td>174</td>\n",
       "      <td>237</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.033149</td>\n",
       "      <td>12</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10664</th>\n",
       "      <td>60</td>\n",
       "      <td>3109</td>\n",
       "      <td>2021-06-30 13:36:57</td>\n",
       "      <td>0</td>\n",
       "      <td>14</td>\n",
       "      <td>0</td>\n",
       "      <td>174</td>\n",
       "      <td>237</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.038674</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10665</th>\n",
       "      <td>60</td>\n",
       "      <td>14125</td>\n",
       "      <td>2021-07-03 06:10:46</td>\n",
       "      <td>0</td>\n",
       "      <td>7</td>\n",
       "      <td>0</td>\n",
       "      <td>174</td>\n",
       "      <td>237</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.027624</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       user_id  article_id           expo_time  net_status  exop_position  \\\n",
       "10661       60        2174 2021-06-30 13:36:57           0             17   \n",
       "10662       60        4458 2021-06-30 13:36:57           0             21   \n",
       "10663       60        4037 2021-06-30 13:40:23           0             24   \n",
       "10664       60        3109 2021-06-30 13:36:57           0             14   \n",
       "10665       60       14125 2021-07-03 06:10:46           0              7   \n",
       "\n",
       "       duration  device  city  age  gender   img_num  cat_1  click  \n",
       "10661         0     174   237    1       1  0.000000     15      0  \n",
       "10662         0     174   237    1       1  0.033149      1      0  \n",
       "10663         0     174   237    1       1  0.033149     12      0  \n",
       "10664         0     174   237    1       1  0.038674     13      0  \n",
       "10665         0     174   237    1       1  0.027624     13      0  "
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_new.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df0b628c",
   "metadata": {},
   "source": [
    "## 数据预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e380f114",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 处理img_num\n",
    "def transform(x):\n",
    "    if x == '上海':\n",
    "        return 0\n",
    "    elif isinstance(x, float):\n",
    "        return float(x)\n",
    "    else:\n",
    "        return float(eval(x))\n",
    "data_new['img_num'] = data_new['img_num'].apply(lambda x: transform(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "46f77eb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "user_id_raw = data_new[['user_id']].drop_duplicates('user_id')\n",
    "doc_id_raw = data_new[['article_id']].drop_duplicates('article_id')\n",
    "\n",
    "# 简单数据预处理\n",
    "sparse_features = [\n",
    "    'user_id', 'article_id', 'net_status', 'exop_position', 'device', 'city', 'age', 'gender', 'cat_1'\n",
    "]\n",
    "dense_features = [\n",
    "    'img_num'\n",
    "]\n",
    "\n",
    "# 填充缺失值\n",
    "data_new[sparse_features] = data_new[sparse_features].fillna('-1')\n",
    "data_new[dense_features] = data_new[dense_features].fillna(0)\n",
    "\n",
    "# 归一化\n",
    "mms = MinMaxScaler(feature_range=(0, 1))\n",
    "data_new[dense_features] = mms.fit_transform(data_new[dense_features])\n",
    "\n",
    "feature_max_idx = {}\n",
    "for feat in sparse_features:\n",
    "    lbe = LabelEncoder()\n",
    "    data_new[feat] = lbe.fit_transform(data_new[feat])\n",
    "    feature_max_idx[feat] = data_new[feat].max() + 1000\n",
    "\n",
    "# 构建用户id词典和doc的id词典，方便从用户idx找到原始的id\n",
    "# user_id_enc = data[['user_id']].drop_duplicates('user_id')\n",
    "# doc_id_enc = data[['article_id']].drop_duplicates('article_id')\n",
    "# user_idx_2_rawid = dict(zip(user_id_enc['user_id'], user_id_raw['user_id']))\n",
    "# doc_idx_2_rawid = dict(zip(doc_id_enc['article_id'], doc_id_raw['article_id']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ad2691ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 划分数据集  这里按照曝光时间划分\n",
    "train_data = data_new[data_new['expo_time'] < '2021-07-06']\n",
    "test_data = data_new[data_new['expo_time'] >= '2021-07-06']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad199918",
   "metadata": {},
   "source": [
    "## 特征封装"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d0e5ea3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "sparse_feature_columns = [SparseFeat(feat, feature_max_idx[feat], embedding_dim=4) for feat in sparse_features]\n",
    "Dense_feature_columns = [DenseFeat(feat, 1) for feat in dense_features]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c9f538ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 划分dnn和linear特征\n",
    "dnn_features_columns = sparse_feature_columns + Dense_feature_columns\n",
    "lhuc_feature_columns = sparse_feature_columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f5355501",
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_names = get_feature_names(dnn_features_columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c6a5f286",
   "metadata": {},
   "outputs": [],
   "source": [
    "# AttributeError: 'numpy.dtype[int64]' object has no attribute 'base_dtype' \n",
    "# Keras需要把输入声明为Keras张量，其他的比如numpy张量作为输入不好使\n",
    "train_model_input = {name: tf.keras.backend.constant(train_data[name]) for name in feature_names}\n",
    "test_model_input = {name: tf.keras.backend.constant(test_data[name]) for name in feature_names}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e4803b7",
   "metadata": {},
   "source": [
    "## 模型搭建"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ab040dc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.keras.layers import *\n",
    "from tensorflow.keras.models import *\n",
    "from tensorflow.python.keras.models import Model\n",
    "from tensorflow.python.keras.initializers import Zeros, glorot_normal\n",
    "from tensorflow.python.keras.regularizers import l2\n",
    "\n",
    "from collections import OrderedDict\n",
    "import itertools\n",
    "from deepctr.feature_column import SparseFeat, VarLenSparseFeat, DenseFeat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "9527b790",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ShareBottom import DNN, build_input_layers, build_embedding_layers, concat_embedding_list, combined_dnn_input, concat_func, PredictionLayer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3f2e5bc",
   "metadata": {},
   "source": [
    "## Lhuc_net:\n",
    "lhuc_net: 语言识别领域的模型，核心思想是做说话人的自适应，其中一个关键突破是DNN网络中，为每个说话人学习一个特定的隐式单位贡献，提升不同说话人的语音效果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "c0000c69",
   "metadata": {},
   "outputs": [],
   "source": [
    "def lhuc_net(name, nn_inputs, lhuc_inputs, nn_hidden_units=(128, 64, ), lhuc_units=(32, ), \n",
    "             dnn_activation='relu', l2_reg_dnn=0, dnn_dropout=0, dnn_use_bn=False, scale_last=True, seed=2021):\n",
    "    \"\"\"这个网络是全连接网络搭建的，主要完成lhuc_feature与其他特征的交互， 算是一个特征交互层，不过交互的方式非常新颖\n",
    "    \n",
    "        name: 为当前lhuc_net起的名字\n",
    "        nn_inputs: 与lhuc_feature进行交互的特征输入，比如fm_out， 或者其他特征的embedding拼接等\n",
    "        lhuc_inputs: lhuc_net的特征输入，在推荐里面，这个其实是能体现用户个性化的一些特征embedding等\n",
    "        nn_hidden_units: 普通DNN每一层神经单元个数\n",
    "        lhuc_units: lhuc_net的神经单元个数\n",
    "        后面就是激活函数， 正则化以及bn的指定参数，不过多解释\n",
    "    \"\"\"\n",
    "    \n",
    "    # nn_inputs可以是其他特征的embedding拼接向量，或者是其他网络的输出，比如fM的输出向量等\n",
    "    cur_layer = nn_inputs       \n",
    "    \n",
    "    # 这里的nn_hidden_units是一个列表，里面是全连接每一层神经单元个数\n",
    "    for idx, nn_dim in enumerate(nn_hidden_units):\n",
    "        # lhuc_feature走一个塔， 这个塔两层， 最终输出的向量维度和nn_inputs的向量维度保持一致， 每个值在0-1之间，代表权重\n",
    "        # 表示fm_embedding或者其他特征embdding每个维度上的重要性  \n",
    "        # 这里其实可以用多层 激活函数用relu \n",
    "        lhuc_output = DNN(lhuc_units, dnn_activation, l2_reg_dnn, dnn_dropout, dnn_use_bn, \n",
    "                          seed=seed, name=\"{}_lhuc_{}\".format(name, idx))(lhuc_inputs)\n",
    "        # 最后这里的输出维度要和交互的embedding保持一致， 激活函数是sigmoid，\n",
    "        lhuc_scale = Dense(int(cur_layer.shape[1]), activation='sigmoid')(lhuc_output)\n",
    "        \n",
    "        # 有了权重之后， lhuc_scale与nn_inputs再过一个塔\n",
    "        cur_layer = DNN((nn_dim, ), dnn_activation, l2_reg_dnn, dnn_dropout, dnn_use_bn, \n",
    "                        seed=seed, name=\"{}_layer_{}\".format(name, idx))(cur_layer * lhuc_scale * 2.0)\n",
    "        \n",
    "    # 上面这个操作相当于nn_input_embedding过了len(nn_hidden_units)层全连接， 只不过，在过每一层之前，会先lhuc_slot特征通过lhuc_net为\n",
    "    # nn_input_embedding过完全连接之后的每个维度学习权重，作为每个维度的重要性\n",
    "    # 如果最后的输出还需要加权，再走一遍上面的操作\n",
    "    if scale_last:\n",
    "        lhuc_output = DNN(lhuc_units, dnn_activation, l2_reg_dnn, dnn_dropout, dnn_use_bn, \n",
    "                          seed=seed, name=\"{}_lhuc_{}\".format(name, len(nn_hidden_units)))(lhuc_inputs)\n",
    "        lhuc_scale = Dense(int(cur_layer.shape[1]), activation='sigmoid')(lhuc_output)\n",
    "        \n",
    "        cur_layer = cur_layer * lhuc_scale * 2.0\n",
    "    \n",
    "    return cur_layer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4845b132",
   "metadata": {},
   "source": [
    "上⾯我觉得关键是lhuc_slot⽤的特征是什么， 为什么这样可以对fm_embedding或者bias_embedding进⾏加权操作。 得需要搞明⽩这⾥⾯⽤的特征之间的制约性或者相关性。 \n",
    "* lhuc_feature: 主要是用户id， doc_id，doc_类别， doc_字数， doc_作者等拼接， 这些都是用户和item的强烈代表特征， 这个拼接的embedding代表的是用户对于item的兴趣偏好\n",
    "* 待交互的特征：\n",
    "    * bias_nn_inputs: 这里一般是原始的特征embedding拼接起来，代表特征的原始信息\n",
    "    * 其他模块输出，比如fm的输出: 这个是能产生交互的特征embedding，代表的是重要的特征交互信息\n",
    "\n",
    "所以，lhuc_net主要是在原始信息或者是特征交互信息过DNN的每一层，先对DNN每一层的输出的每个维度，根据用户对于item的兴趣偏好，进行加权，来提升每一层DNN输出的不同维度的贡献程度，来体现用户的个性化信息(相比于不加lhuc_net)，此外，还能进行降维。毕竟通过个性化进行了一波选择吗。有点像Fibinet那里的se模块，不过那个是对每个embedding进行加权筛选，而这里是对DNN输出(可以看成一个embedding)的每个维度进行加权。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8497eebe",
   "metadata": {},
   "source": [
    "所以我感觉这个lhuc_net的思路也是⾮常不错的， 相当于在原来的基础上，通过⽤⼾对于视频的兴趣 偏好，对embedding的各个维度进⾏加权，提升不同维度的贡献程度。相当于只提取了更加重要的⼀ 些维度信息。 既节省了计算量，⼜避免维度冗余。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "b9e32b52",
   "metadata": {},
   "outputs": [],
   "source": [
    "class BilinearInteraction(Layer):\n",
    "    def __init__(self, bilinear_type=\"interaction\", seed=2022, **kwargs):\n",
    "        super(BilinearInteraction, self).__init__(**kwargs)\n",
    "        self.bilinear_type = bilinear_type\n",
    "        self.seed = seed\n",
    "    def build(self, input_shape):\n",
    "        # input_shape: [None, field_num, embed_num]\n",
    "        self.field_size = input_shape[1]\n",
    "        self.embedding_size = input_shape[-1]\n",
    "        \n",
    "        if self.bilinear_type == 'all':  #所有embedding矩阵共用一个矩阵W\n",
    "            self.W = self.add_weight(shape=(self.embedding_size, self.embedding_size), \n",
    "                                     initializer=glorot_normal(seed=self.seed), name=\"bilinear_weight\")\n",
    "        elif self.bilinear_type == \"each\": # 每个field共用一个矩阵W\n",
    "            self.W_list = [self.add_weight(shape=(self.embedding_size, self.embedding_size), initializer=glorot_normal(\n",
    "                seed=self.seed), name=\"bilinear_weight\" + str(i)) for i in range(self.field_size-1)]\n",
    "        elif self.bilinear_type == \"interaction\":  # 每个交互用一个矩阵W\n",
    "            self.W_list = [self.add_weight(shape=(self.embedding_size, self.embedding_size), initializer=glorot_normal(\n",
    "                seed=self.seed), name=\"bilinear_weight\" + str(i) + '_' + str(j)) for i, j in\n",
    "                           itertools.combinations(range(self.field_size), 2)]\n",
    "        super(BilinearInteraction, self).build(input_shape)  # Be sure to call this somewhere!\n",
    "    \n",
    "    def call(self, inputs):\n",
    "        # inputs: [None, field_nums, embed_dims]\n",
    "        # 这里把inputs从field_nums处split, 划分成field_nums个embed_dims长向量的列表\n",
    "        inputs = tf.split(inputs, self.field_size, axis=1)  # [(None, embed_dims), (None, embed_dims), ..] \n",
    "        n = len(inputs)  # field_nums个\n",
    "        \n",
    "        if self.bilinear_type == \"all\":\n",
    "            # inputs[i] (none, embed_dims)    self.W (embed_dims, embed_dims) -> (None, embed_dims)\n",
    "            vidots = [tf.tensordot(inputs[i], self.W, axes=(-1, 0)) for i in range(n)]   # 点积\n",
    "            p = [tf.multiply(vidots[i], inputs[j]) for i, j in itertools.combinations(range(n), 2)]  # 哈达玛积\n",
    "        elif self.bilinear_type == \"each\":\n",
    "            vidots = [tf.tensordot(inputs[i], self.W_list[i], axes=(-1, 0)) for i in range(n - 1)]\n",
    "            # 假设3个域， 则两两组合[(0,1), (0,2), (1,2)]  这里的vidots是第一个维度， inputs是第二个维度 哈达玛积运算\n",
    "            p = [tf.multiply(vidots[i], inputs[j]) for i, j in itertools.combinations(range(n), 2)]\n",
    "        elif self.bilinear_type == \"interaction\":\n",
    "            # combinations(inputs, 2)  这个得到的是两两向量交互的结果列表\n",
    "            # 比如 combinations([[1,2], [3,4], [5,6]], 2)\n",
    "            # 得到 [([1, 2], [3, 4]), ([1, 2], [5, 6]), ([3, 4], [5, 6])]  (v[0], v[1]) 先v[0]与W点积，然后再和v[1]哈达玛积\n",
    "            p = [tf.multiply(tf.tensordot(v[0], w, axes=(-1, 0)), v[1])\n",
    "                 for v, w in zip(itertools.combinations(inputs, 2), self.W_list)]\n",
    "        \n",
    "        output = Concatenate(axis=1)(p)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "33dd687f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def SharedBottom(dnn_feature_columns, lhuc_feature_columns, bottom_dnn_hidden_units=(256, 128), tower_dnn_hidden_units=(64, ), \n",
    "                l2_reg_embedding=0.00001, l2_reg_dnn=0, seed=2021, dnn_dropout=0, dnn_activation='relu',\n",
    "                dnn_use_bn=False, task_types=('binary', 'binary'), task_names=('ctr', 'ctcvr'), bilinear_type='interaction'):\n",
    "    \n",
    "    num_tasks = len(task_names)\n",
    "    \n",
    "    # 异常判断\n",
    "    for task_type in task_types:\n",
    "        if task_type not in ['binary', 'regression']:\n",
    "            raise ValueError(\"task must be binary or regression, {} is illegal\".format(task_type))\n",
    "    \n",
    "    # 构建Input层并将Input层转成列表作为模型的输入\n",
    "    input_layer_dict = build_input_layers(dnn_feature_columns)\n",
    "    input_layers = list(input_layer_dict.values())\n",
    "    \n",
    "    # 筛选出特征中的sparse和Dense特征， 后面要单独处理\n",
    "    sparse_feature_columns = list(filter(lambda x: isinstance(x, SparseFeat), dnn_feature_columns))\n",
    "    dense_feature_columns = list(filter(lambda x: isinstance(x, DenseFeat), dnn_feature_columns))\n",
    "    \n",
    "    # 获取Dense Input\n",
    "    dnn_dense_input = []\n",
    "    for fc in dense_feature_columns:\n",
    "        dnn_dense_input.append(input_layer_dict[fc.name])\n",
    "    \n",
    "    # 构建embedding字典\n",
    "    embedding_layer_dict = build_embedding_layers(dnn_feature_columns)\n",
    "    # 离散的这些特特征embedding之后，然后拼接，然后直接作为全连接层Dense的输入，所以需要进行Flatten\n",
    "    dnn_sparse_embed_input = concat_embedding_list(sparse_feature_columns, input_layer_dict, embedding_layer_dict, flatten=False)\n",
    "    \n",
    "    # 把连续特征和离散特征合并起来\n",
    "    bias_input = combined_dnn_input(dnn_sparse_embed_input, dnn_dense_input)\n",
    "    \n",
    "    # 下面dnn_sparse_embed_input进行双线性交互\n",
    "    bilinear_out = BilinearInteraction(bilinear_type=bilinear_type)(Concatenate(axis=1)(dnn_sparse_embed_input))\n",
    "    \n",
    "    # lhuc_features_columns\n",
    "    lhuc_input = concat_embedding_list(lhuc_feature_columns, input_layer_dict, embedding_layer_dict, flatten=True)\n",
    "    lhuc_input = concat_func(lhuc_input)\n",
    "    \n",
    "    # bilinear_out与lhuc_input过lhuc_net\n",
    "    bilinear_out_flatt = Flatten()(bilinear_out)\n",
    "    bilinear_lhuc_out = lhuc_net(\"bilinear_lhuc\", bilinear_out_flatt, lhuc_input)\n",
    "    \n",
    "    # bias_input与lhuc_input过lhuc_net\n",
    "    bias_lhuc_out = lhuc_net(\"bias_lhuc\", bias_input, lhuc_input)\n",
    "    \n",
    "    # 两个输出拼接就是双线性net的最终输出结果，汇总了原始信息和交叉信息， 且通过lhuc_net对维度加权，在DNN每一层做一个维度筛选\n",
    "    sb_out = Concatenate(axis=-1)([bilinear_lhuc_out, bias_lhuc_out])\n",
    "    \n",
    "    # 每个任务独立的tower\n",
    "    task_outputs = []\n",
    "    for task_type, task_name in zip(task_types, task_names):\n",
    "        # 建立tower\n",
    "        tower_output = DNN(tower_dnn_hidden_units, dnn_activation, l2_reg_dnn, dnn_dropout, dnn_use_bn, seed=2022, name='tower_'+task_name)(sb_out)\n",
    "        logit = Dense(1, use_bias=False, activation=None)(tower_output)\n",
    "        output = PredictionLayer(task_type, name=task_name)(logit)\n",
    "        task_outputs.append(output)\n",
    "    \n",
    "    model = Model(inputs=input_layers, outputs=task_outputs)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "9b0270fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SharedBottom(dnn_features_columns, lhuc_feature_columns, tower_dnn_hidden_units=[], task_types=['regression', 'binary'], \n",
    "             task_names=['duration', 'click'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "8edfda88",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb7b39f2",
   "metadata": {},
   "source": [
    "## 模型的训练和预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "39172a8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(\"adam\", loss={\"duration\": \"mean_squared_error\", \"click\": \"binary_crossentropy\"}, \n",
    "              loss_weights={\"duration\": 0.02, \"click\": 0.98},\n",
    "              metrics={\"duration\": \"mae\", \"click\": \"binary_crossentropy\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "051a9186",
   "metadata": {},
   "outputs": [],
   "source": [
    "label_duration = tf.keras.backend.constant(train_data['duration'].values)\n",
    "label_click = tf.keras.backend.constant(train_data['click'].values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "c32ea007",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 124189 samples, validate on 31048 samples\n",
      "Epoch 1/10\n",
      "124189/124189 [==============================] - 33s 265us/sample - loss: 129.4515 - duration_loss: 6447.7607 - click_loss: 0.4120 - duration_mae: 35.6925 - click_binary_crossentropy: 0.4121 - val_loss: 123.5746 - val_duration_loss: 6150.1084 - val_click_loss: 0.4234 - val_duration_mae: 42.4280 - val_click_binary_crossentropy: 0.4235\n",
      "Epoch 2/10\n",
      "124189/124189 [==============================] - 26s 205us/sample - loss: 121.6239 - duration_loss: 6062.1685 - click_loss: 0.3736 - duration_mae: 34.3889 - click_binary_crossentropy: 0.3737 - val_loss: 134.1051 - val_duration_loss: 6676.2964 - val_click_loss: 0.4221 - val_duration_mae: 50.3868 - val_click_binary_crossentropy: 0.4222\n",
      "Epoch 3/10\n",
      "124189/124189 [==============================] - 26s 212us/sample - loss: 115.6162 - duration_loss: 5760.5908 - click_loss: 0.3616 - duration_mae: 32.8438 - click_binary_crossentropy: 0.3617 - val_loss: 152.3486 - val_duration_loss: 7586.0078 - val_click_loss: 0.4573 - val_duration_mae: 55.2336 - val_click_binary_crossentropy: 0.4575 duration_loss: 5748.2695 - click_loss: 0.3617 - duration_mae: 32.8089 - click_binary_crossentropy:\n",
      "Epoch 4/10\n",
      "124189/124189 [==============================] - 26s 207us/sample - loss: 110.7495 - duration_loss: 5519.1523 - click_loss: 0.3527 - duration_mae: 31.6416 - click_binary_crossentropy: 0.3528 - val_loss: 163.1362 - val_duration_loss: 8125.6582 - val_click_loss: 0.4373 - val_duration_mae: 57.0959 - val_click_binary_crossentropy: 0.4375\n",
      "Epoch 5/10\n",
      "124189/124189 [==============================] - 26s 207us/sample - loss: 106.4043 - duration_loss: 5304.7515 - click_loss: 0.3469 - duration_mae: 30.7003 - click_binary_crossentropy: 0.3468 - val_loss: 155.4251 - val_duration_loss: 7738.9189 - val_click_loss: 0.4656 - val_duration_mae: 52.5892 - val_click_binary_crossentropy: 0.4658\n",
      "Epoch 6/10\n",
      "124189/124189 [==============================] - 25s 205us/sample - loss: 102.2682 - duration_loss: 5093.8828 - click_loss: 0.3399 - duration_mae: 29.7638 - click_binary_crossentropy: 0.3400 - val_loss: 171.8720 - val_duration_loss: 8560.7080 - val_click_loss: 0.4716 - val_duration_mae: 57.8897 - val_click_binary_crossentropy: 0.4718\n",
      "Epoch 7/10\n",
      "124189/124189 [==============================] - 26s 206us/sample - loss: 97.6583 - duration_loss: 4865.2354 - click_loss: 0.3332 - duration_mae: 28.8503 - click_binary_crossentropy: 0.3332 - val_loss: 185.9917 - val_duration_loss: 9265.3018 - val_click_loss: 0.4786 - val_duration_mae: 58.4092 - val_click_binary_crossentropy: 0.4789\n",
      "Epoch 8/10\n",
      "124189/124189 [==============================] - 26s 208us/sample - loss: 93.7351 - duration_loss: 4668.0156 - click_loss: 0.3261 - duration_mae: 27.8894 - click_binary_crossentropy: 0.3261 - val_loss: 184.5322 - val_duration_loss: 9195.3936 - val_click_loss: 0.4249 - val_duration_mae: 57.3572 - val_click_binary_crossentropy: 0.4251\n",
      "Epoch 9/10\n",
      "124189/124189 [==============================] - 26s 209us/sample - loss: 89.5580 - duration_loss: 4459.2490 - click_loss: 0.3181 - duration_mae: 26.9729 - click_binary_crossentropy: 0.3182 - val_loss: 192.6629 - val_duration_loss: 9596.4922 - val_click_loss: 0.5015 - val_duration_mae: 54.8409 - val_click_binary_crossentropy: 0.5017\n",
      "Epoch 10/10\n",
      "124189/124189 [==============================] - 26s 210us/sample - loss: 85.8638 - duration_loss: 4276.8545 - click_loss: 0.3123 - duration_mae: 26.0333 - click_binary_crossentropy: 0.3122 - val_loss: 196.2682 - val_duration_loss: 9776.9883 - val_click_loss: 0.4954 - val_duration_mae: 56.1660 - val_click_binary_crossentropy: 0.4956\n"
     ]
    }
   ],
   "source": [
    "history = model.fit(train_model_input, [label_duration, label_click],\n",
    "                        batch_size=128, epochs=10, verbose=1, validation_split=0.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "8d671e57",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_ans = model.predict(test_model_input, batch_size=256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "32dc4aa3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test click AUC 0.6216\n"
     ]
    }
   ],
   "source": [
    "print(\"test click AUC\", round(roc_auc_score(test_data['click'], pred_ans[1]), 4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "6c377e25",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test duration 52.0987\n"
     ]
    }
   ],
   "source": [
    "print(\"test duration\", round(mean_absolute_error(test_data['duration'], pred_ans[0]), 4))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
